#!/usr/bin/env python
# -*- coding: utf-8 -*-
#from __future__ import division, with_statement
'''
Copyright 2013, 陈同 (chentong_biology@163.com).  
===========================================================
'''
__author__ = 'chentong & ct586[9]'
__author_email__ = 'chentong_biology@163.com'
#=========================================================
# Long help text; cmdparameter() writes it to stderr when the script is run
# without any command-line arguments.
desc = '''
Functional description:

    This program is designed to find the genomic coordinates of given mRNA
    fragments.

    It requires three files (details and examples below):
        1. Full mRNA sequences in FASTA format.
        2. The genomic coordinate of each mRNA nucleotide in a FASTA-like
        format (generated by mapPostionsOfNtForEachTranscript.py).
        3. The fragments you want to locate, in FASTA format, each with a
        label indicating its original transcript.

mRNA.fa:

>NM_001011874@Xkr4
gcggcggcgggcgagcgggcgctggagtaggagctggggagcggcgcggccggggaaggaagccagggcgaggcgaggaggtggcgggaggaggagacagcagggacaggTGTCAGATAAAGGAGTGCTCTCCTCCGCTGCCGAGGCATCATGGCCGCTAAGTCAGACGGGAGGCTGAAGATGAAGAAGAGCAGCGACG

coordinate.fa:

>NM_008390@Irf1
chr11,+,53583515,53583516,53583517,53583518,53583519,53583520,53583521,53583522,53583523,53583524,53583525,53583526,53583527,53583528,53583529,53583530,53583531,53583532,53583533,53583534,53583535

frag.fa: (The string between the first and second hash symbol (#)
should match the FASTA name in mRNA.fa and coordinate.fa)

>mmu-miR-7036-5p#NM_177898@Nek5#1#133.00@-20.40
tggatgtgCCAGGGACCCCGTg
>mmu-miR-7036-5p#NM_177898@Nek5#2#126.00@-23.15
catgcTCACTGCTTAACGCCCGCc
>mmu-miR-7036-5p#NM_177898@Nek5#3#126.00@-27.92
ctgtcaCCCCACCACCCCTGCa


'''

import sys
import os
from time import localtime, strftime 
timeformat = "%Y-%m-%d %H:%M:%S"
from optparse import OptionParser as OP

#from mapPostionsOfNtForEachTranscript import readFasta





def cmdparameter(argv):
    '''Parse the command line.

    argv : the full argument vector; argv[0] is the program name.

    Returns (options, args) as produced by optparse.

    With no arguments at all, the long description (``desc``) goes to
    stderr, the option summary goes to stdout, and the process exits with
    status 1. A missing -i, -t or -c is reported through parser.error,
    which prints a message and exits with status 2.
    '''
    parser = OP(usage="%prog -i file")
    parser.add_option("-i", "--input-file", dest="filein",
        metavar="FILEIN", help="FASTA file containing short fragments.")
    parser.add_option("-t", "--transcript", dest="transcript",
        help="The transcript file in FASTA format.")
    parser.add_option("-c", "--coordinate", dest="coordinate",
        help="The coordinate file in FASTA format.")
    # type="int" so that "-s 0" / "-v 0" really mean false. Without it
    # optparse returns the non-empty, and therefore truthy, string "0".
    parser.add_option("-s", "--small-frag", dest="small_frag", type="int",
        default=0, help="A value to indicate the amount of frags "
        "in file given to <-i>. This parameter will not affect the result "
        "but right parameter may promote processing speed. When there are "
        "more frags than references, <0> should be given here "
        "(which is default). Otherwise <1> should be given. ")
    parser.add_option("-v", "--verbose", dest="verbose", type="int",
        default=0, help="Show process information")
    parser.add_option("-d", "--debug", dest="debug",
        default=False, help="Debug the program")
    if len(argv) == 1:
        # No arguments: show the long description plus the option summary
        # in-process, instead of re-running the script through a
        # hard-coded "python" interpreter in a shell.
        sys.stderr.write(desc + "\n")
        parser.print_help()
        sys.exit(1)
    (options, args) = parser.parse_args(argv[1:])
    # All three input files are required; fail early with a clear message
    # rather than with open(None) later in main().
    for value, flag in ((options.filein, "-i"),
                        (options.transcript, "-t"),
                        (options.coordinate, "-c")):
        if value is None:
            parser.error("A filename needed for %s" % flag)
    return (options, args)
#--------------------------------------------------------------------

def readFasta(fh, idD=''):
    '''Read FASTA records into a dict mapping record name to sequence.

    fh  : any iterable of text lines (an open file, sys.stdin, a list).
    idD : optional container of record names to keep. When empty (the
          default), every record is kept.

    Returns {name: sequence}. The name is the header line without the
    leading '>' and surrounding whitespace. Multi-line sequences are joined.
    Lines that appear before the first header are ignored. Previously they
    raised NameError because the reading state was never initialised.
    '''
    seqDict = {}
    locus = None
    read = 0
    for line in fh:
        # line[:1] instead of line[0] so an empty string cannot IndexError.
        if line[:1] == '>':
            locus = line.strip()[1:]
            read = not idD or locus in idD
        elif read:
            seqDict.setdefault(locus, []).append(line.strip())
    return {key: ''.join(valueL) for key, valueL in seqDict.items()}
#--------------------------------------------

def readCoordinates(coord, idD=''):
    '''Read the per-nucleotide coordinate file.

    Expected layout (one comma-separated data line per record):
    >Sequence_name
    chr1,-,3,2,1,...............coords..........
    >Sequence_name
    chr1,+,1,2,3,...............coords..........

    coord : path to the coordinate file.
    idD   : optional container of record names to keep. When empty (the
            default), every record is kept.

    Returns {name: [chrom, strand, coord_list]}, where coord_list holds
    the integer fields after chrom and strand, in file order. Blank lines
    and empty trailing fields (e.g. a trailing comma) are ignored. An
    AssertionError is raised if a kept record has more than one data line.
    '''
    coordDict = {}
    key = None
    read = 0
    # Use a with block so the file is closed even if parsing fails.
    with open(coord) as coordFh:
        for line in coordFh:
            line = line.strip()
            if not line:
                # A blank line used to produce [''] and crash on lineL[1].
                continue
            if line[0] == '>':
                key = line[1:]
                read = not idD or key in idD
            elif read:
                lineL = line.split(',')
                chrom = lineL[0]
                strand = lineL[1]
                coord_list = [int(i) for i in lineL[2:] if i]
                assert key not in coordDict, \
                    "Duplicate coordinate record: %s" % key
                coordDict[key] = [chrom, strand, coord_list]
    return coordDict
#--------------------------------------------------


#def transferListToRegion(coord_list, order=1):
#    #Default we assume coord_list is sorted numerically
#    if order:
#        if coord_list[0] > coord_list[1]:
#            coord_list.reverse()
#    else:
#        coord_list.sort()
#    #merged bins in initialMergedBin
#    initialMergedBin = []
#    binarySearch(coord_list, len(coord_list), initialMergedBin)
#    #Merge continuous bins one by one
#    mergedBin = []
#    start = initialMergedBin[0][0]
#    end   = initialMergedBin[0][1]
#    for i in initialMergedBin[1:]:
#        if i[0] == end:
#            end = i[1]
#        else:
#            mergedBin.append((start, end))
#            start = i[0]
#            end   = i[1]
#    #--add the last bin---------------
#    mergedBin.append((start, i[1]))
#    #-------Finish all-------------------------
#    return mergedBin
##---------------------------------------------
#
#def binarySearch(coord_list, length, tmpL):
#    '''
#    This uses binary search to find continuous regions.
#
#    The coord_list should in numerical order from small to big.
#    '''
#    start = coord_list[0]
#    end   = coord_list[-1] + 1
#    assert start < end, "Wrong order of coord_list"
#    if end - start == length:
#        tmpL.extend([(start, end)])
#        return tmpL
#    else:
#        half = length / 2
#        binarySearch(coord_list[:half],half,tmpL)
#        binarySearch(coord_list[half:],length-half,tmpL)
#    #----------------------------------------
##---------------------------------------------

def findAllsubstr(substr, str):
    '''Return every occurrence of substr in str, overlapping ones included.

    posL = [[start, end], [start, end], ...] with 0-based, half-open
    offsets, so str[start:end] == substr.

    An empty substr returns []. Otherwise str.find would report
    len(str)+1 zero-length hits, and mapping those to coordinates crashes
    on an empty coordinate slice.
    '''
    posL = []
    if not substr:
        return posL
    len_sub = len(substr)
    pos = str.find(substr)
    while pos != -1:
        posL.append([pos, pos + len_sub])
        # Restart one past the previous hit so overlapping matches count.
        pos = str.find(substr, pos + 1)
    return posL
#---------------------------
def transferListToRegion(coord_list, order=1):
    '''Collapse integer positions into half-open [start, end) runs.

    coord_list : the positions to merge. It is NOT modified; the function
                 works on a copy.
    order      : when true (default), the list is assumed to be monotonic
                 and is only reversed if it is descending (judged from its
                 first two values). When false, it is sorted numerically.

    Returns [[start, end], ...]. Consecutive positions (difference of 1)
    are merged into one run, and each end is exclusive. A single position
    p gives [[p, p+1]]. An empty list gives [].
    '''
    coords = list(coord_list)
    if not coords:
        return []
    if order:
        # Guard the length: coords[1] does not exist for a 1-nt fragment.
        if len(coords) > 1 and coords[0] > coords[1]:
            coords.reverse()
    else:
        coords.sort()
    regionL = []
    start = coords[0]
    for prev, cur in zip(coords, coords[1:]):
        if cur - prev > 1:
            regionL.append([start, prev + 1])
            start = cur
    # Close the last open run.
    regionL.append([start, coords[-1] + 1])
    return regionL
#--------------------------------------------

def posToCoord(posL, coordL):
    '''Translate transcript match offsets into genomic regions.

    posL   : [[start, end], ...], 0-based half-open offsets into the
             transcript (the shape returned by findAllsubstr).
    coordL : [chr, strand, coord_list], where coord_list[i] is the genomic
             position of transcript nucleotide i.

    Returns multipleRegionL = [
        [[start, end], ...]  # genomic regions for one matched position
    ]
    '''
    coord_list = coordL[2]
    return [transferListToRegion(coord_list[start:end])
            for start, end in posL]
#---------------------------------

def transferFragToCoord(frag, refS, coordL):
    '''Find frag inside refS and convert every hit to genomic regions.

    frag   : fragment sequence to search for.
    refS   : full transcript sequence.
    coordL : [chr, strand, coord_list] for that transcript.

    Returns the per-hit region list from posToCoord, or the sentinel string
    'unmap' when frag does not occur in refS.
    '''
    hitL = findAllsubstr(frag, refS)
    if not hitL:
        return 'unmap'
    return posToCoord(hitL, coordL)
#-------------------------------------

def output(label, coordL, multipleRegionL, fragS):
    '''Write the located regions of one fragment to standard output.

    label           : fragment name (its FASTA header without '>').
    coordL          : [chr, strand, coord_list]; only chr and strand are used.
    multipleRegionL : [[[start, end], ...], ...] with one inner list per
                      matched position, or the sentinel string 'unmap'.
    fragS           : the fragment sequence.

    Each region gives one tab-separated line with these columns:
    chr, start, end, label#N (N = 1-based match number), fragS, strand.
    For 'unmap', the single line "<label> unmapped" is written instead.
    '''
    write = sys.stdout.write
    if multipleRegionL == 'unmap':
        write("%s unmapped\n" % label)
        return
    chrom = coordL[0]
    strand = coordL[1]
    for matchNo, singleRegionL in enumerate(multipleRegionL, 1):
        for region in singleRegionL:
            regStart, regEnd = region[0], region[1]
            write("%s\t%d\t%d\t%s#%d\t%s\t%s\n" %
                  (chrom, regStart, regEnd, label, matchNo, fragS, strand))

#---------------------------------------------


def _asFlag(value):
    '''Interpret a command-line flag value (int or raw string) as a boolean.'''
    # 0, "0" and "" are false; any other integer is true. A non-numeric
    # string falls back to Python truthiness (the old behaviour).
    try:
        return int(value) != 0
    except (TypeError, ValueError):
        return bool(value)


def _iterFastaRecords(fh):
    '''Yield (name, sequence) for each FASTA record in fh, one at a time.'''
    # Streams the input, so memory is bounded by one record. Multi-line
    # sequences are joined, and lines before the first header are ignored.
    name = None
    seqL = []
    for line in fh:
        if line[:1] == '>':
            if name is not None:
                yield name, ''.join(seqL)
            name = line.strip()[1:]
            seqL = []
        elif name is not None:
            seqL.append(line.strip())
    if name is not None:
        yield name, ''.join(seqL)


def _locateFrag(fragN, fragS, transcriptD, coordDict):
    '''Locate one fragment on its source transcript and print the result.'''
    # The source transcript name is the field between the first and second
    # '#' of the fragment header.
    label = fragN.split('#')[1]
    coordL = coordDict[label]
    multipleRegionL = transferFragToCoord(fragS, transcriptD[label], coordL)
    output(fragN, coordL, multipleRegionL, fragS)


def main():
    '''Locate every fragment given to -i and print its genomic regions.

    The -s option picks one of two strategies:
      non-zero : load all fragments first, then read only the transcripts
                 and coordinates they reference.
      zero     : (default) load all transcripts and coordinates, then
                 stream the fragments one record at a time.
    An -i value of '-' reads fragments from standard input.
    '''
    options = cmdparameter(sys.argv)[0]
    fragFile = options.filein
    # optparse may hand back strings such as "0", which are truthy.
    small_frag = _asFlag(options.small_frag)
    verbose = _asFlag(options.verbose)
    fh = sys.stdin if fragFile == '-' else open(fragFile)
    try:
        if small_frag:
            fragD = readFasta(fh)
            # Names of the transcripts referenced by at least one fragment.
            geneD = dict((fragN.split('#')[1], 1) for fragN in fragD)
            with open(options.transcript) as fh2:
                transcriptD = readFasta(fh2, geneD)
            coordDict = readCoordinates(options.coordinate, geneD)
            fragIter = iter(fragD.items())
        else:
            with open(options.transcript) as fh2:
                transcriptD = readFasta(fh2)
            coordDict = readCoordinates(options.coordinate)
            # Stream record by record. Fixed-size line chunks could split a
            # multi-line record across chunks and lose its header.
            fragIter = _iterFastaRecords(fh)
        for fragN, fragS in fragIter:
            _locateFrag(fragN, fragS, transcriptD, coordDict)
    finally:
        if fh is not sys.stdin:
            fh.close()
    if verbose:
        sys.stderr.write(
            "--Successful %s\n" % strftime(timeformat, localtime()))
if __name__ == '__main__':
    # Run the program, then append the command line and its start/end wall
    # clock times to python.log in the current working directory.
    startTime = strftime(timeformat, localtime())
    main()
    endTime = strftime(timeformat, localtime())
    logLine = "%s\n\tRun time : %s - %s " % \
        (' '.join(sys.argv), startTime, endTime)
    with open('python.log', 'a') as logFh:
        logFh.write(logLine + "\n")



